{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "42145ff5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "def plot_preds(train, test, pred_dict, model_name, show_samples=False):\n",
    "    pred = pred_dict['median']\n",
    "    pred = pd.Series(pred, index=test.index)\n",
    "    plt.figure(figsize=(8, 6), dpi=100)\n",
    "    plt.plot(train, color='black')\n",
    "    plt.plot(test, label='Truth', color='black')\n",
    "    plt.plot(pred, label=model_name, color='purple')\n",
    "    # shade 90% confidence interval\n",
    "    samples = pred_dict['samples']\n",
    "    lower = np.quantile(samples, 0.05, axis=0)\n",
    "    upper = np.quantile(samples, 0.95, axis=0)\n",
    "    plt.fill_between(pred.index, lower, upper, alpha=0.3, color='purple')\n",
    "    if show_samples:\n",
    "        samples = pred_dict['samples']\n",
    "        # convert df to numpy array\n",
    "        samples = samples.values if isinstance(samples, pd.DataFrame) else samples\n",
    "        for i in range(min(10, samples.shape[0])):\n",
    "            plt.plot(pred.index, samples[i], color='purple', alpha=0.3, linewidth=1)\n",
    "    plt.legend(loc='upper left')\n",
    "    if 'NLL/D' in pred_dict:\n",
    "        nll = pred_dict['NLL/D']\n",
    "        if nll is not None:\n",
    "            plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes, bbox=dict(facecolor='white', alpha=0.5))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67ef0416",
   "metadata": {},
   "source": [
    "## Darts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bd9d037",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.small_context import get_datasets\n",
    "\n",
    "output_dir = 'precomputed_outputs/darts'\n",
    "datasets = get_datasets()\n",
    "for ds_name, data in datasets.items():\n",
    "    print(ds_name)\n",
    "    data = datasets[ds_name]\n",
    "    train, test = data\n",
    "    with open(f'{output_dir}/{ds_name}.pkl', 'rb') as f:\n",
    "        out = pickle.load(f)\n",
    "    for model in out:\n",
    "        plot_preds(train, test, out[model], model, show_samples=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a33dde27",
   "metadata": {},
   "source": [
    "## Synthetic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc6df724",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.synthetic import get_synthetic_datasets\n",
    "output_dir = 'precomputed_outputs/synthetic'\n",
    "datasets = get_synthetic_datasets()\n",
    "for ds_name, data in datasets.items():\n",
    "    print(ds_name)\n",
    "    data = datasets[ds_name]\n",
    "    train, test = data\n",
    "    with open(f'{output_dir}/{ds_name}.pkl', 'rb') as f:\n",
    "        out = pickle.load(f)\n",
    "    for model in out:\n",
    "        plot_preds(train, test, out[model], model, show_samples=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf162a91",
   "metadata": {},
   "source": [
    "## Monash"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e6ff231f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_monash_preds(train, test, pred_dict, model_name, max_series):\n",
    "    for i in range(min(max_series, len(test))):\n",
    "        pred = pd.Series(pred_dict['median'][i], index=test[i].index)\n",
    "        plt.figure(figsize=(8, 6), dpi=100)\n",
    "        plt.plot(train[i], color='black')\n",
    "        plt.plot(test[i], label='Truth', color='black')\n",
    "        plt.plot(pred, label=model_name, color='purple')\n",
    "        plt.legend(loc='upper left')\n",
    "        ymax = max(train[i].max(), test[i].max()) * 1.1\n",
    "        ymin = plt.gca().get_ylim()[0]\n",
    "        plt.ylim(ymin, ymax)\n",
    "        if 'NLL/D' in pred_dict:\n",
    "            nll = pred_dict['NLL/D'][i]\n",
    "            if nll is not None:\n",
    "                plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes, bbox=dict(facecolor='white', alpha=0.5))\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30c34822",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.monash import get_datasets\n",
    "output_dir = 'precomputed_outputs/monash'\n",
    "max_history_len = 500\n",
    "datasets = get_datasets()\n",
    "for ds_name, data in datasets.items():\n",
    "    if not os.path.exists(f'{output_dir}/{ds_name}.pkl'):\n",
    "        continue\n",
    "    print(ds_name)\n",
    "    data = datasets[ds_name]\n",
    "    train, test = data\n",
    "    train = [x[-max_history_len:] for x in train]\n",
    "    # turn into pd series\n",
    "    train = [pd.Series(train[i], index=pd.RangeIndex(len(train[i]))) for i in range(len(train))]\n",
    "    test = [pd.Series(test[i], index=pd.RangeIndex(len(train[i]), len(train[i]) + len(test[i]))) for i in range(len(test))]\n",
    "    \n",
    "    with open(f'{output_dir}/{ds_name}.pkl', 'rb') as f:\n",
    "        out = pickle.load(f)\n",
    "    for model in out:\n",
    "        plot_monash_preds(train, test, out[model], model, max_series=3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fspace",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "vscode": {
   "interpreter": {
    "hash": "9436057e92285046d415c34e216bd357b01decd87fa7e06f42744a4b160880c3"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
